"""
@Project ：microproduct 
@File ：ImageHandle.py
@Function ：实现对待处理SAR数据的读取、格式标准化和处理完后保存文件功能
@Author ：LMM
@Date ：2021/10/19 14:39
@Version ：1.0.0
"""
import os
from PIL import Image
from osgeo import gdal
from osgeo import osr
import numpy as np
from PIL import Image
import cv2
import logging
import math


class ImageHandler:
    """
    影像读取、编辑、保存
    """
    def __init__(self):
        """Create a handler; the constructor stores no per-instance state."""
    @staticmethod
    def get_dataset(filename):
        """
        Open a raster file with GDAL and hand back the dataset handle.

        :param filename: path of the raster file (e.g. a GeoTIFF)
        :return: whatever ``gdal.Open`` returns, i.e. the dataset handle or None
        """
        gdal.AllRegister()
        # gdal.Open already yields None on failure, so its result is returned as is.
        return gdal.Open(filename)

    def get_scope(self, filename):
        """
        Compute the corner coordinates of a raster file.

        :param filename: path of the raster file
        :return: the result of ``cal_img_scope`` for the opened dataset,
                 or None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is None:
            return None
        corners = self.cal_img_scope(raster)
        del raster  # drop our reference to the GDAL handle
        return corners

    @staticmethod
    def get_projection(filename):
        """
        Read the projection definition stored in a raster file.

        :param filename: path of the raster file
        :return: the value of ``dataset.GetProjection()``, or None when
                 ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is not None:
            projection = raster.GetProjection()
            del raster  # drop our reference to the GDAL handle
            return projection
        return None

    @staticmethod
    def get_geotransform(filename):
        """
        Read the affine geotransform of a raster file.

        The geotransform maps image coordinates (column, row) to georeferenced
        coordinates (projected or geographic).

        :param filename: path of the raster file
        :return: the value of ``dataset.GetGeoTransform()``, or None when
                 ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is not None:
            affine = raster.GetGeoTransform()
            del raster  # drop our reference to the GDAL handle
            return affine
        return None

    def get_invgeotransform(filename):
        """
        :param filename: tif路径
        :return: 从地理参考坐标空间（投影或地理坐标）的到图像坐标空间（行、列
        """
        gdal.AllRegister()
        dataset = gdal.Open(filename)
        if dataset is None:
            return None
        geotransform = dataset.GetGeoTransform()
        geotransform=gdal.InvGeoTransform(geotransform)
        del dataset
        return geotransform        
    
    @staticmethod
    def get_bands(filename):
        """
        Report how many bands a raster file has.

        :param filename: path of the raster file
        :return: ``dataset.RasterCount``, or None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is not None:
            band_count = raster.RasterCount
            del raster  # drop our reference to the GDAL handle
            return band_count
        return None

    @staticmethod
    def geo2lonlat(dataset, x, y):
        """
        Convert a coordinate in the dataset's own CRS to longitude/latitude.

        The source CRS is taken from ``dataset.GetProjection()``; the target is
        the geographic CRS cloned from it.

        :param dataset: opened GDAL dataset providing the projection
        :param x: x coordinate in the dataset's CRS
        :param y: y coordinate in the dataset's CRS
        :return: the first two components of the transformed point, (lon, lat)
        """
        prosrs = osr.SpatialReference()
        prosrs.ImportFromWkt(dataset.GetProjection())
        geosrs = prosrs.CloneGeogCS()
        # GDAL >= 3 follows the authority-defined axis order by default, which
        # for many geographic CRSs is (lat, lon). Force the traditional
        # x/y = lon/lat order so the result matches the documented contract.
        # Older bindings lack the constant and already behave that way.
        if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
            prosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            geosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        ct = osr.CoordinateTransformation(prosrs, geosrs)
        coords = ct.TransformPoint(x, y)
        return coords[:2]

    
    
    @staticmethod
    def get_band_array(filename, num=1):
        """
        Read one band of a raster file into an array.

        The pixel values are returned exactly as read; no nodata substitution
        is applied here.

        :param filename: path of the raster file
        :param num: 1-based band index
        :return: the full-extent array of that band, or None when
                 ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is None:
            return None
        band = raster.GetRasterBand(num)
        pixels = band.ReadAsArray(0, 0, band.XSize, band.YSize)
        del raster  # drop our reference to the GDAL handle
        return pixels

    @staticmethod
    def get_data(filename):
        """
        Read the full extent of every band of a raster file.

        :param filename: path of the raster file
        :return: the array produced by ``dataset.ReadAsArray`` over the whole
                 image, or None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is None:
            return None
        n_cols, n_rows = raster.RasterXSize, raster.RasterYSize
        pixels = raster.ReadAsArray(0, 0, n_cols, n_rows)
        del raster  # drop our reference to the GDAL handle
        return pixels

    @staticmethod
    def get_all_band_array(filename):
        """
        Read every band of a raster into one array with bands on the last axis.

        Used by the atmospheric-delay algorithm for ERA-5 rasters. Where
        ``get_data`` yields e.g. (37, 8, 8), this yields (8, 8, 37).

        :param filename: path of the raster file
        :return: for a single-band file, the 2-D band array as read;
                 otherwise a float array of shape (rows, cols, bands);
                 None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        dataset = gdal.Open(filename)
        if dataset is None:
            # Same convention as the other readers in this class.
            return None

        x_size = dataset.RasterXSize
        y_size = dataset.RasterYSize
        nums = dataset.RasterCount

        if nums == 1:
            # Single band: return the band unchanged (2-D, native dtype).
            array = dataset.GetRasterBand(1).ReadAsArray(0, 0, x_size, y_size)
        else:
            # Only allocate the float cube when it is actually needed.
            array = np.zeros((y_size, x_size, nums), dtype=float)
            for i in range(nums):
                band = dataset.GetRasterBand(i + 1)
                array[:, :, i] = band.ReadAsArray(0, 0, x_size, y_size)

        del dataset  # drop our reference to the GDAL handle
        return array

    @staticmethod
    def get_img_width(filename):
        """
        Report the width of a raster file in pixels.

        :param filename: path of the raster file
        :return: ``dataset.RasterXSize``, or None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is not None:
            n_cols = raster.RasterXSize
            del raster  # drop our reference to the GDAL handle
            return n_cols
        return None

    @staticmethod
    def get_img_height(filename):
        """
        Report the height of a raster file in pixels.

        :param filename: path of the raster file
        :return: ``dataset.RasterYSize``, or None when ``gdal.Open`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is not None:
            n_rows = raster.RasterYSize
            del raster  # drop our reference to the GDAL handle
            return n_rows
        return None

    @staticmethod
    def read_img(filename):
        """
        Read projection, geotransform and pixel data of a raster in one call.

        :param filename: path of the raster file
        :return: tuple (projection, geotransform, array); (None, None, None)
                 when the file cannot be opened or ``GetProjection`` returns None
        """
        gdal.AllRegister()
        raster = gdal.Open(filename)
        if raster is None:
            print('Could not open ' + filename)
            return None, None, None

        projection = raster.GetProjection()
        if projection is None:
            return None, None, None

        geotransform = raster.GetGeoTransform()
        n_cols = raster.RasterXSize  # number of columns (image width)
        n_rows = raster.RasterYSize  # number of rows (image height)
        pixels = raster.ReadAsArray(0, 0, n_cols, n_rows)
        del raster  # drop our reference to the GDAL handle
        return projection, geotransform, pixels
    
    
    
    
    def cal_img_scope(self, dataset):
        """
        计算影像的地理坐标范围
        根据GDAL的六参数模型将影像图上坐标（行列号）转为投影坐标或地理坐标（根据具体数据的坐标系统转换）
        :param dataset :GDAL地理数据
        :return: list[point_upleft, point_upright, point_downleft, point_downright]
        """
        if dataset is None:
            return None

        img_geotrans = dataset.GetGeoTransform()
        if img_geotrans is None:
            return None

        width = dataset.RasterXSize  # 栅格矩阵的列数
        height = dataset.RasterYSize  # 栅格矩阵的行数

        point_upleft = self.trans_rowcol2geo(img_geotrans, 0, 0)
        point_upright = self.trans_rowcol2geo(img_geotrans, width, 0)
        point_downleft = self.trans_rowcol2geo(img_geotrans, 0, height)
        point_downright = self.trans_rowcol2geo(img_geotrans, width, height)

        return [point_upleft, point_upright, point_downleft, point_downright]

    @staticmethod
    def get_scope_ori_sim(filename):
        """
        Return the band-1/band-2 values found at the four corner pixels.

        Unlike ``cal_img_scope`` this does not use the geotransform: every
        corner is ``[band1[row, col], band2[row, col]]``.
        NOTE(review): presumably bands 1 and 2 store per-pixel x/y (lon/lat)
        lookup values - confirm with the producer of the file.

        :param filename: path of a raster file with at least two bands
        :return: [upper-left, upper-right, lower-left, lower-right];
                 None when the file cannot be opened or has fewer than two bands
        """
        gdal.AllRegister()
        dataset = gdal.Open(filename)
        if dataset is None:
            return None
        if dataset.RasterCount < 2:
            del dataset
            return None

        last_col = dataset.RasterXSize - 1
        last_row = dataset.RasterYSize - 1
        band1 = dataset.GetRasterBand(1)
        band2 = dataset.GetRasterBand(2)

        # Read 1x1 windows instead of loading both full bands just to look at
        # four pixels; the values are the same as indexing the full arrays.
        corner_pixels = ((0, 0), (last_col, 0), (0, last_row), (last_col, last_row))
        scope = []
        for col, row in corner_pixels:
            scope.append([band1.ReadAsArray(col, row, 1, 1)[0, 0],
                          band2.ReadAsArray(col, row, 1, 1)[0, 0]])

        del dataset  # drop our reference to the GDAL handle
        return scope


    @staticmethod
    def trans_rowcol2geo(img_geotrans,img_col, img_row):
        """
        据GDAL的六参数模型仿射矩阵将影像图上坐标（行列号）转为投影坐标或地理坐标（根据具体数据的坐标系统转换）
        :param img_geotrans: 仿射矩阵
        :param img_col:图像纵坐标
        :param img_row:图像横坐标
        :return: [geo_x,geo_y]
        """
        geo_x = img_geotrans[0] + img_geotrans[1] * img_col + img_geotrans[2] * img_row
        geo_y = img_geotrans[3] + img_geotrans[4] * img_col + img_geotrans[5] * img_row
        return [geo_x, geo_y]

    @staticmethod
    def write_era_into_img(filename, im_proj, im_geotrans, im_data):
        """
        Save an array as a GeoTIFF, with bands on the LAST axis of 3-D input.

        :param filename: output path; a missing parent directory is created
        :param im_proj: projection passed to ``SetProjection``
        :param im_geotrans: geotransform passed to ``SetGeoTransform``
        :param im_data: 2-D array (rows, cols) or 3-D array (rows, cols, bands)
        :return: None
        """
        # Keys are numpy dtype names. Unknown dtypes fall back to Float32.
        # NOTE(review): int8 is written as (unsigned) Byte, as before -
        # confirm negative values are never expected here.
        gdal_dtypes = {
            'int8': gdal.GDT_Byte,
            'uint8': gdal.GDT_Byte,
            'uint16': gdal.GDT_UInt16,
            'int16': gdal.GDT_Int16,
            'uint32': gdal.GDT_UInt32,
            'int32': gdal.GDT_Int32,
            'float32': gdal.GDT_Float32,
            'float64': gdal.GDT_Float64,
        }
        datatype = gdal_dtypes.get(im_data.dtype.name, gdal.GDT_Float32)

        # Work out the raster dimensions from the array shape.
        is_multiband_layout = len(im_data.shape) == 3
        if is_multiband_layout:
            im_height, im_width, im_bands = im_data.shape
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape

        # Create the output directory; a bare file name has no directory part
        # and os.makedirs('') would raise.
        out_dir = os.path.split(filename)[0]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # The data type is required up front so GDAL can size the file.
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)

        if is_multiband_layout:
            # Slice per band so that (rows, cols, 1) input is also written
            # as a 2-D array.
            for i in range(im_bands):
                dataset.GetRasterBand(i + 1).WriteArray(im_data[:, :, i])
        else:
            dataset.GetRasterBand(1).WriteArray(im_data)
        del dataset  # releasing the handle finalises the file

        # 写GeoTiff文件

    @staticmethod
    def write_img(filename, im_proj, im_geotrans, im_data, no_data='null'):
        """
        Save an array as a GeoTIFF, with bands on the FIRST axis of 3-D input.

        :param filename: output path; a missing parent directory is created
        :param im_proj: projection passed to ``SetProjection``
        :param im_geotrans: geotransform passed to ``SetGeoTransform``
        :param im_data: 2-D array (rows, cols) or 3-D array (bands, rows, cols)
        :param no_data: value registered as nodata on every band; the default
                        string 'null' means "do not set a nodata value"
        :return: None
        """
        # Keys are numpy dtype names. Unknown dtypes fall back to Float32.
        # NOTE(review): int8 is written as (unsigned) Byte, as before -
        # confirm negative values are never expected here.
        gdal_dtypes = {
            'int8': gdal.GDT_Byte,
            'uint8': gdal.GDT_Byte,
            'uint16': gdal.GDT_UInt16,
            'int16': gdal.GDT_Int16,
            'uint32': gdal.GDT_UInt32,
            'int32': gdal.GDT_Int32,
            'float32': gdal.GDT_Float32,
            'float64': gdal.GDT_Float64,
        }
        datatype = gdal_dtypes.get(im_data.dtype.name, gdal.GDT_Float32)

        # Work out the raster dimensions and the per-band 2-D arrays.
        if len(im_data.shape) == 3:
            im_bands, im_height, im_width = im_data.shape
            band_arrays = [im_data[i] for i in range(im_bands)]
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            band_arrays = [im_data]

        # Create the output directory; a bare file name has no directory part
        # and os.makedirs('') would raise.
        out_dir = os.path.split(filename)[0]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # The data type is required up front so GDAL can size the file.
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)

        for band_index, band_array in enumerate(band_arrays, start=1):
            outband = dataset.GetRasterBand(band_index)
            outband.WriteArray(band_array)
            # Apply the nodata value to every band, not only to single-band output.
            if no_data != 'null':
                outband.SetNoDataValue(no_data)
            outband.FlushCache()
        del dataset  # releasing the handle finalises the file

        # 写GeoTiff文件

    @staticmethod
    def write_img_rpc(filename, im_proj, im_geotrans, im_data, rpc_dict):
        """
        Save an array as a GeoTIFF and attach RPC metadata to it.

        :param filename: output path
        :param im_proj: projection passed to ``SetProjection``
        :param im_geotrans: geotransform passed to ``SetGeoTransform``
        :param im_data: 2-D array (rows, cols) or 3-D array (bands, rows, cols)
        :param rpc_dict: mapping of metadata item name to value, written into
                         the 'RPC' metadata domain
        :return: None
        """
        # Exact dtype-name lookup, the same table as write_img uses. The former
        # substring test mapped uint16 to Int16 and degraded int32/uint32/float64
        # to Float32. Unknown dtypes still fall back to Float32.
        gdal_dtypes = {
            'int8': gdal.GDT_Byte,
            'uint8': gdal.GDT_Byte,
            'uint16': gdal.GDT_UInt16,
            'int16': gdal.GDT_Int16,
            'uint32': gdal.GDT_UInt32,
            'int32': gdal.GDT_Int32,
            'float32': gdal.GDT_Float32,
            'float64': gdal.GDT_Float64,
        }
        datatype = gdal_dtypes.get(im_data.dtype.name, gdal.GDT_Float32)

        # Work out the raster dimensions and the per-band 2-D arrays.
        if len(im_data.shape) == 3:
            im_bands, im_height, im_width = im_data.shape
            band_arrays = [im_data[i] for i in range(im_bands)]
        else:
            im_bands, (im_height, im_width) = 1, im_data.shape
            band_arrays = [im_data]

        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)

        # Write the RPC parameters into the 'RPC' metadata domain.
        for key, value in rpc_dict.items():
            dataset.SetMetadataItem(key, value, 'RPC')

        for band_index, band_array in enumerate(band_arrays, start=1):
            dataset.GetRasterBand(band_index).WriteArray(band_array)

        del dataset  # releasing the handle finalises the file


    def transtif2mask(self,out_tif_path, in_tif_path, threshold):
        """
        :param out_tif_path:输出路径
        :param in_tif_path:输入的路径
        :param threshold:阈值
        """
        im_proj, im_geotrans, im_arr, im_scope = self.read_img(in_tif_path)
        im_arr_mask = (im_arr < threshold).astype(int)
        self.write_img(out_tif_path, im_proj, im_geotrans, im_arr_mask)

    def write_quick_view(self, tif_path, color_img=False, quick_view_path=None):
        """
        Write a reduced-size quick-look image for a raster.

        Single-band rasters are used as they are; for multi-band rasters the
        first two bands are combined as sqrt(b0**2 + b1**2).

        :param tif_path: path of the raster file
        :param color_img: when True, paint every distinct non-zero value with a
                          random colour; otherwise write a grey-scale image
        :param quick_view_path: output path; defaults to ``tif_path`` with its
                                extension replaced by '.jpg'
        """
        if quick_view_path is None:
            quick_view_path = os.path.splitext(tif_path)[0]+'.jpg'

        band_count = self.get_bands(tif_path)
        t_data = self.get_data(tif_path)
        if band_count != 1:
            # Multi-band: convert to intensity from the first two bands.
            t_data = t_data.astype(float)
            t_data = np.sqrt(t_data[0] ** 2 + t_data[1] ** 2)

        # t_data is 2-D (rows, cols) here, so the image size can be taken from
        # the array instead of reopening the file twice.
        t_r, t_c = t_data.shape[:2]

        # Choose the quick-look size: a tenth for very large images, longest
        # side 1024 for medium ones, unchanged otherwise.
        if t_r > 10000 or t_c > 10000:
            q_r = int(t_r / 10)
            q_c = int(t_c / 10)
        elif 1024 < t_r < 10000 or 1024 < t_c < 10000:
            if t_r > t_c:
                q_r = 1024
                q_c = int(t_c/t_r * 1024)
            else:
                q_c = 1024
                q_r = int(t_r/t_c * 1024)
        else:
            q_r = t_r
            q_c = t_c

        if color_img is True:
            # Pseudo-colour image: zero stays black, every other distinct value
            # gets one random colour (components drawn from 0..254).
            img = np.zeros((t_r, t_c, 3), dtype=np.uint8)  # (height, width, channels)
            for value in np.unique(t_data):
                if value != 0:
                    img[t_data == value] = np.random.randint(0, 255, size=3)

            img = cv2.resize(img, (q_c, q_r))  # (width, height)
            cv2.imwrite(quick_view_path, img)
        else:
            # Grey-scale image.
            low = np.nanmin(t_data)
            high = np.nanmax(t_data)
            t_data[np.isnan(t_data)] = high
            # Compute the range in Python floats so narrow integer dtypes cannot
            # wrap around, and skip scaling for constant images (range 0), which
            # would otherwise divide by zero.
            span = float(high) - float(low)
            # NOTE(review): data whose range is 256 or more is passed through
            # unscaled and left to convert("L") - confirm this is intended.
            if 0 < span < 256:
                t_data = (t_data - low) / span * 255
            out_img = Image.fromarray(t_data)
            out_img = out_img.resize((q_c, q_r))  # resample to quick-look size
            out_img = out_img.convert("L")  # 8-bit grey scale
            out_img.save(quick_view_path)

    def limit_field(self, out_path, in_path, min_value, max_value):
        """
        :param out_path:输出路径
        :param in_path:主mask路径，输出影像采用主mask的地理信息
        :param min_value
        :param max_value
        """
        proj = self.get_projection(in_path)
        geotrans = self.get_geotransform(in_path)
        array = self.get_band_array(in_path, 1)
        array[array < min_value] = min_value
        array[array > max_value] = max_value
        self.write_img(out_path, proj, geotrans, array)
        return True


# if __name__ == '__main__':
#     path = r"I:\MicroWorkspace\product\C-SAR\Ortho\Output\GF3B_MYC_QPSI_003581_E120.6_N31.3_20220729_L1A_AHV_L10000073024_RPC\RPC_ori_sim.tif"
#     s = ImageHandler().get_scope_ori_sim(path)
#     print(s)
#     pass